#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <tuple>
#include <sys/cdefs.h>
#include "esmd_types.h"
#ifdef EXT_LINCS
// Number of do_mm sweeps run in each of the two solve passes of
// lincs::do_lincs / lincs::do_lincs_esmd.
constexpr int nrec = 4;

/// Metafunction: int_tuple<N, Ts...>::type is std::tuple with N `int`
/// elements placed in front of Ts... . Each recursion step lowers N by one and
/// prepends one `int` to the pack.
template<int N, class ...Ts>
struct int_tuple {
  using type = typename int_tuple<N - 1, int, Ts...>::type;
};

/// Recursion end: no more `int`s to add, the accumulated pack becomes the tuple.
template<class ...Ts>
struct int_tuple<0, Ts...> {
  using type = std::tuple<Ts...>;
};
/**
 * Compile-time description of a set of pairwise constraints.
 *
 * @tparam NATOMS  number of `int` slots in idx_type (built with int_tuple).
 * @tparam IJS     flattened pair list i0, j0, i1, j1, ...: constraint c joins
 *                 slots ijs[2*c] and ijs[2*c + 1]. The values are used by
 *                 callers as std::get indices into an idx_type.
 */
template<int NATOMS, int ...IJS>
struct constraint_type{
  // An odd-length list would silently drop its last entry from ncons.
  static_assert(sizeof...(IJS) % 2 == 0, "IJS must contain an even number of entries (i, j pairs)");

  static constexpr int ijs[sizeof...(IJS)] = {IJS...};
  static constexpr int ncons = sizeof...(IJS) / 2;
  typedef typename int_tuple<NATOMS>::type idx_type;

  /**
   * Sign factor for the pair of constraints (ic, jc).
   * @return -1 when both constraints have the same slot in the same position
   *         (first/first or second/second), otherwise 1.
   */
  static constexpr real sign(int ic, int jc) {
    int i0 = ijs[ic * 2];
    int j0 = ijs[ic * 2 + 1];
    int i1 = ijs[jc * 2];
    int j1 = ijs[jc * 2 + 1];
    return (i0 == i1 || j0 == j1) ? -1 : 1;
  }

  /**
   * Slot shared by constraints ic and jc.
   * The first slot of ic is preferred when both of its slots match (e.g. ic == jc).
   * @return the shared slot, or -1 when the two constraints have no slot in
   *         common (previously control flowed off the end of the function,
   *         which is undefined behaviour and not a constant expression).
   */
  static constexpr int conn(int ic, int jc) {
    int i0 = ijs[ic * 2];
    int j0 = ijs[ic * 2 + 1];
    int i1 = ijs[jc * 2];
    int j1 = ijs[jc * 2 + 1];
    if (i0 == i1 || i0 == j1) return i0;
    if (j0 == i1 || j0 == j1) return j0;
    return -1;
  }
};

/// Empty tag type that only carries a pack of ints in its template arguments,
/// so a function template can deduce the pack from an argument of this type.
template<int ...Is>
struct int_seq {};
/// Metafunction: GenSeq<N>::seq is int_seq<0, 1, ..., N-1>.
/// Each step lowers N by one and pushes the new N onto the front of the pack.
template<int N, int ...Is>
struct GenSeq {
  using seq = typename GenSeq<N - 1, N - 1, Is...>::seq;
};

/// Recursion end: the accumulated pack is the result.
template<int ...Is>
struct GenSeq<0, Is...> {
  using seq = int_seq<Is...>;
};
/// Metafunction: GenSeqRev<N>::seq is int_seq<N-1, ..., 1, 0>.
/// Each step lowers N by one and appends the new N to the back of the pack.
template<int N, int ...Is>
struct GenSeqRev {
  using seq = typename GenSeqRev<N - 1, Is..., N - 1>::seq;
};

/// Recursion end: the accumulated pack is the result.
template<int ...Is>
struct GenSeqRev<0, Is...> {
  using seq = int_seq<Is...>;
};
#include <string.h>
/**
 * Fixed-size LINCS-style constraint solver for one group of atoms whose
 * constraint list is described at compile time by T (a constraint_type).
 *
 * Every per-constraint / per-matrix-entry step is a member template that is
 * expanded once per index through the RUNALL pack expansion, so all loops
 * over constraints are unrolled at compile time.
 *
 * Pointer arguments shared by the members:
 *   x        reference positions, used for the constraint directions
 *   xp       positions to be corrected (updated in place)
 *   invmass  per-atom factors, indexed by atom
 *   r0       per-constraint target lengths, indexed by constraint
 *   idx      tuple mapping the slots named in T::ijs to indices into x/xp/invmass
 */
template<class T>
struct lincs{
  real A[T::ncons][T::ncons];   // coupling matrix written by do_init_mat (zero diagonal)
  vec<real> B[T::ncons];        // normalised x[i] - x[j] of each constraint
  real Sdiag[T::ncons];         // per-constraint scaling factor
  real rhs[2][T::ncons];        // two buffers the do_mm sweeps alternate between
  real sol[T::ncons];           // initial right-hand side plus the sum of every do_mm product
  typedef typename T::idx_type idx_type;

  /**
   * Sets up constraint IC: direction B[IC], scaling Sdiag[IC] computed as
   * 1 / sqrt(invmass[i] + invmass[j]), and the initial right-hand side, which
   * is written to the caller-supplied `rhs` buffer and copied into sol.
   */
  template<int IC>
  __always_inline void do_init_rhs(real *rhs, vec<real> *x, vec<real> *xp, real *invmass, real *r0, const idx_type &idx) {
    int i = std::get<T::ijs[IC * 2]>(idx);
    int j = std::get<T::ijs[IC * 2 + 1]>(idx);
    B[IC] = x[i] - x[j];
    B[IC] /= B[IC].norm();
    Sdiag[IC] = 1.0 / sqrt(invmass[i] + invmass[j]);
    rhs[IC] = Sdiag[IC] * (B[IC].dot(xp[i] - xp[j]) - r0[IC]);
    sol[IC] = rhs[IC];
  }

  /**
   * Same as do_init_rhs, except that Sdiag[IC] is taken from alt_sdiag[IC]
   * instead of being computed from invmass (invmass is not read here).
   */
  template<int IC>
  __always_inline void do_init_rhs_esmd(real *rhs, real *alt_sdiag, vec<real> *x, vec<real> *xp, real *invmass, real *r0, const idx_type &idx) {
    int i = std::get<T::ijs[IC * 2]>(idx);
    int j = std::get<T::ijs[IC * 2 + 1]>(idx);
    B[IC] = x[i] - x[j];
    B[IC] /= B[IC].norm();
    Sdiag[IC] = alt_sdiag[IC];
    rhs[IC] = Sdiag[IC] * (B[IC].dot(xp[i] - xp[j]) - r0[IC]);
    sol[IC] = rhs[IC];
  }

  /**
   * Fills one entry of A. IJC is a flattened index: row ic = IJC / ncons,
   * column jc = IJC % ncons. The diagonal is zero; off-diagonal entries couple
   * constraints ic and jc through the atom in the slot reported by T::conn.
   * T::conn(ic, jc) is used as a std::get index for every pair (including the
   * diagonal), so it must name a valid slot of idx for all of them.
   */
  template<int IJC>
  __always_inline void do_init_mat(real *invmass, const idx_type &idx) {
    constexpr int ic = IJC / T::ncons;
    constexpr int jc = IJC % T::ncons;
    constexpr int conn = T::conn(ic, jc);
    int c = std::get<conn>(idx);
    if (ic == jc){
      A[ic][jc] = 0;
    } else {
      A[ic][jc] = T::sign(ic, jc) * invmass[c] * Sdiag[ic] * Sdiag[jc] * B[ic].dot(B[jc]);
    }
  }

  /**
   * One scalar step of rhs_next = A * rhs_cur. The row accumulator is reset
   * at the first column, and the finished row sum is added to sol at the last
   * column, so the steps must run for IJC = 0 .. ncons*ncons-1 in order.
   */
  template<int IJC>
  __always_inline void do_mm(real *rhs_next, real *rhs_cur){
    constexpr int ic = IJC / T::ncons;
    constexpr int jc = IJC % T::ncons;
    if (jc == 0) {
      rhs_next[ic] = 0;
    }
    rhs_next[ic] += A[ic][jc] * rhs_cur[jc];
    if (jc == T::ncons - 1) {
      sol[ic] += rhs_next[ic];
    }
  }

  /** Applies the solution of constraint IC to xp: atom i is moved along -B[IC], atom j along +B[IC]. */
  template<int IC>
  __always_inline void do_update(vec<real> *xp, real *invmass, const idx_type &idx) {
    int i = std::get<T::ijs[IC * 2]>(idx);
    int j = std::get<T::ijs[IC * 2 + 1]>(idx);
    xp[i] -= B[IC] * invmass[i] * Sdiag[IC] * sol[IC];
    xp[j] += B[IC] * invmass[j] * Sdiag[IC] * sol[IC];
  }

  /** Accumulates the solution of constraint IC into f, scaled by rdtfsq and without the invmass factor. */
  template<int IC>
  __always_inline void do_update_esmd(vec<real> *f, real rdtfsq, const idx_type &idx) {
    int i = std::get<T::ijs[IC * 2]>(idx);
    int j = std::get<T::ijs[IC * 2 + 1]>(idx);
    f[i] -= B[IC] * Sdiag[IC] * sol[IC] * rdtfsq;
    f[j] += B[IC] * Sdiag[IC] * sol[IC] * rdtfsq;
  }

  /**
   * Builds the right-hand side of the second (length-correction) pass from the
   * current xp: p = sqrt(2 * r0^2 - |xp[i] - xp[j]|^2), rhs = Sdiag * (r0 - p).
   * When the value under the root is negative, p is taken as 0 instead of
   * letting sqrt produce NaN that would then spread into xp.
   */
  template<int IC>
  __always_inline void do_correct(real *rhs, vec<real> *xp, real *r0, const idx_type &idx) {
    int i = std::get<T::ijs[IC * 2]>(idx);
    int j = std::get<T::ijs[IC * 2 + 1]>(idx);
    const auto radicand = 2 * r0[IC] * r0[IC] - (xp[i] - xp[j]).norm2();
    // `< 0` (not `> 0`) so that a NaN radicand still reaches sqrt and stays visible.
    real p = (radicand < 0) ? real(0) : real(sqrt(radicand));
    rhs[IC] = Sdiag[IC] * (r0[IC] - p);
    sol[IC] = rhs[IC];
  }

  /** Atom index of constraint I: D == 0 selects its first slot, D == 1 its second. */
  template<int I, int D>
  __always_inline int gidx(const idx_type &idx) {
    return std::get<T::ijs[I*2+D]>(idx);
  }

  /** Unfinished placeholder; the body is only a sketch kept as comments. */
  template<int ...Is, int ...Js>
  __always_inline void do_lincs_clean(vec<real> *x, vec<real> *xp, real *invmass, real *r0, const idx_type &idx){
    // __always_inline vec<real> get_B(int i, int j) {
    //   return (x[i] - x[j]).unit_vec();
    // }
    // __always_inline vec<real> get_Sdiag(int i, int j) {
    //   return (x[i] - x[j])
    // }
    // B = {(x[gidx<Is,0>(idx)] - x[gidx<Is,1>(idx)]).unit_vec()...};
    // Sdiag = {(1.0 / sqrt(invmass[gidx<Is, 0>(idx)] + invmass[gidx<Is, 1>(idx)]))...};
    // rhs = {(Sdiga[Is] * (B[Is].dot(xp[gidx<Is,0>(idx)] - xp[gidx<Is,1>(idx)])))...};
    // sol = {rhs[Is]...};
  }

  /** Debug output: the ncons entries of `row` separated by single spaces, then a newline. */
  void print_row(const real *row) const {
    for (int k = 0; k < T::ncons; k ++) {
      if (k > 0) putchar(' ');
      printf("%f", row[k]);
    }
    putchar('\n');
  }

  /** Debug output: the line "A:" followed by one line per row of A. */
  void print_A() const {
    puts("A:");
    for (int r = 0; r < T::ncons; r ++) print_row(A[r]);
  }

// Evaluates x once per element of the parameter pack it mentions, as the
// elements of a throw-away int array initializer (GNU compound-literal syntax).
#define RUNALL(x) (int[]){(x, 0)...}
  /**
   * Full solve. Is... must enumerate the constraints and IJs... the entries of
   * A (see run). First pass: init rhs/A, nrec do_mm sweeps, update xp.
   * Second pass: rebuild rhs with do_correct, nrec sweeps, update xp again.
   * Prints the initial rhs and A to stdout as debug output.
   */
  template<int ...Is, int ...IJs>
  __always_inline void do_lincs(int_seq<Is...> seq1d, int_seq<IJs...> seq2d, vec<real> *x, vec<real> *xp, real *invmass, real *r0, const idx_type &idx) {
    RUNALL(do_init_rhs<Is>(rhs[0], x, xp, invmass, r0, idx));
    // Sized by ncons instead of the former fixed three columns, which read
    // out of bounds for smaller constraint sets.
    printf("rhs:");
    print_row(rhs[0]);

    RUNALL(do_init_mat<IJs>(invmass, idx));
    print_A();
    for (int i = 0; i < nrec; i ++) {
      // (i & 1) ^ 1 is the destination buffer, i & 1 the source.
      RUNALL(do_mm<IJs>(rhs[i&1^1], rhs[i&1]));
    }
    RUNALL(do_update<Is>(xp, invmass, idx));
    RUNALL(do_correct<Is>(rhs[0], xp, r0, idx));
    for (int i = 0; i < nrec; i ++) {
      RUNALL(do_mm<IJs>(rhs[i&1^1], rhs[i&1]));
    }
    RUNALL(do_update<Is>(xp, invmass, idx));
  }

  /**
   * Variant of do_lincs that takes the scaling factors from alt_sdiag and,
   * after each pass, also accumulates into f via do_update_esmd (scaled by
   * rdtfsq). Prints debug output to stdout.
   */
  template<int ...Is, int ...IJs>
  __always_inline void do_lincs_esmd(int_seq<Is...> seq1d, int_seq<IJs...> seq2d, real *alt_sdiag, real rdtfsq, vec<real> *x, vec<real> *xp, vec<real> *f, real *invmass, real *r0, const idx_type &idx) {
    // `rhs` here is still the member: the local of the same name is declared below.
    RUNALL(do_init_rhs_esmd<Is>(rhs[0], alt_sdiag, x, xp, invmass, r0, idx));
    // NOTE(review): the four locals below shadow the members of the same name.
    // The local Sdiag is recomputed from invmass while this->Sdiag holds
    // alt_sdiag, and the memcpy replaces this->sol with the locally computed
    // values while this->rhs[0] and this->Sdiag keep the alt_sdiag-based ones.
    // This looks like a leftover comparison experiment -- confirm the mix is intended.
    vec<real> B[] = {(x[gidx<Is,0>(idx)] - x[gidx<Is,1>(idx)]).unit_vec()...};
    real Sdiag[] = {(1.0 / sqrt(invmass[gidx<Is, 0>(idx)] + invmass[gidx<Is, 1>(idx)]))...};
    real rhs[] = {(Sdiag[Is] * (B[Is].dot(xp[gidx<Is,0>(idx)] - xp[gidx<Is,1>(idx)]) - r0[Is]))...};
    real sol[] = {rhs[Is]...};
    memcpy(this->sol, sol, sizeof(sol));
    // Compares the two scalings for at most the first two constraints; the
    // extra bound keeps the read in range when ncons < 2.
    for (int i = 0; i < 2 && i < T::ncons; i ++) {
      printf("rhis: %f %f\n", Sdiag[i], this->Sdiag[i]);
    }
    RUNALL(do_init_mat<IJs>(invmass, idx));
    print_A();
    for (int i = 0; i < nrec; i ++) {
      RUNALL(do_mm<IJs>(this->rhs[i&1^1], this->rhs[i&1]));
    }
    RUNALL(do_update<Is>(xp, invmass, idx));
    RUNALL(do_update_esmd<Is>(f, rdtfsq, idx));
    RUNALL(do_correct<Is>(this->rhs[0], xp, r0, idx));
    for (int i = 0; i < nrec; i ++) {
      RUNALL(do_mm<IJs>(this->rhs[i&1^1], this->rhs[i&1]));
    }
    RUNALL(do_update<Is>(xp, invmass, idx));
    RUNALL(do_update_esmd<Is>(f, rdtfsq, idx));
  }
#undef RUNALL

  /** Entry point: builds the index sequences 0..ncons-1 and 0..ncons*ncons-1 and runs do_lincs. */
  void run(vec<real> *x, vec<real> *xp, real *invmass, real *r0, idx_type idx) {
    typename GenSeq<T::ncons>::seq seq1d;
    typename GenSeq<T::ncons*T::ncons>::seq seq2d;
    do_lincs(seq1d, seq2d, x, xp, invmass, r0, idx);
  }

  /** Entry point for the esmd variant; Sdiag is forwarded as do_lincs_esmd's alt_sdiag. */
  void run_esmd(vec<real> *x, vec<real> *xp, vec<real> *f, real *Sdiag, real *invmass, real *r0, idx_type idx, real rdtfsq) {
    typename GenSeq<T::ncons>::seq seq1d;
    typename GenSeq<T::ncons*T::ncons>::seq seq2d;
    do_lincs_esmd(seq1d, seq2d, Sdiag, rdtfsq, x, xp, f, invmass, r0, idx);
  }
};
#else
#define INLINE static inline
#include "shake.hpp"
#endif
// const int nrec = 40;
int main(int argc, char **argv){
  const int N = 3;
  vec<real> x[N] = {{1.891914, 1.718993, 6.916843},
                    {2.325009, 2.464732, 6.505912},
                    {1.786481, 2.042712, 7.812702}};
  vec<real> xp[N] = {{1.896918, 1.712022, 6.907338},
                    {2.349682, 2.449896, 6.494517},
                    {1.754187, 2.075783, 7.812295}};
  vec<real> f[5];
  real invmass[N] = {0.06250234383789392, 0.992063, 0.992063};
  real alt_sdiag[N] = {0.9482596425211665, 0.9482596425211665, 0.504};
  real r0[N] = {0.957200, 0.957200, 1.513901};
  for (int i = 0; i < N; i ++) printf("%f %f %f\n", xp[i].x, xp[i].y, xp[i].z);
  for (int i = 0; i < N; i ++) {
    for (int j = 0; j < N; j ++) {
      printf("%f ", (xp[i] - xp[j]).norm());
    }
    puts("");
  }
  #ifdef EXT_LINCS
  shake<constraint_type<3, 0, 1, 0, 2, 1, 2>> l;
  l.run(x, xp, invmass, r0, std::make_tuple(0, 1, 2));
  #else
  shake l = {4, 200, 1e-6};
  rigid_rec rec;
  int idx[3];
  idx[0] = 0;
  idx[1] = 1;
  idx[2] = 2;
  rec.param.shake.rsq[0] = r0[0] * r0[0];
  rec.param.shake.rsq[1] = r0[1] * r0[1];
  rec.param.shake.rsq[2] = r0[2] * r0[2];
  // l.run_static(x, xp, invmass, rec, idx);
  l.run_static<rigid3angle>(x, xp, invmass, rec, idx);
  l.run_static<rigid3angle>(x, xp, invmass, rec, idx);
  l.run_static<rigid3angle>(x, xp, invmass, rec, idx);
  #endif
  // l.run_esmd(x, xp, f, alt_sdiag, invmass, r0, std::make_tuple(0, 1, 2), 400);
  // printf("%f %f\n%f %f\n", l.A[0][0], l.A[0][1], l.A[1][0], l.A[1][1]);
  for (int i = 0; i < N; i ++) printf("%f %f %f\n", xp[i].x, xp[i].y, xp[i].z);
  for (int i = 0; i < N; i ++) {
    for (int j = 0; j < N; j ++) {
      printf("%f ", (xp[i] - xp[j]).norm());
    }
    puts("");
  }
}
/*
x: 1.891914 1.718993 6.916843
x: 2.325009 2.464732 6.505912
x: 1.786481 2.042712 7.812702
1.891914 1.718993 6.916843
localf: 1 5 3 26
26 28 29
1.891914 1.718993 6.916843
2.325009 2.464732 6.505912
1.786481 2.042712 7.812702
1.896918 1.712022 6.907338
2.349682 2.449896 6.494517
1.754187 2.075783 7.812295
0.955281 0.958370 1.475064
0.957200 0.957200 1.513901
1.899491 1.716478 6.904937
2.294065 2.367913 6.568999
1.768962 2.087040 7.775935
1.894751 1.711796 6.914914
2.401289 2.511716 6.402960
1.736973 2.017558 7.783604
0.832412 0.955506 1.345851
0.957200 0.957200 1.513901
*/